{% extends 'base.exported.class' %}

{% block content %}
package main

import (
	"encoding/json"
	"fmt"
	"io/ioutil"
	"os"
	"strconv"
)

// {{ class_name }} holds a tree-structured classifier decoded from JSON.
// The tree is stored as parallel slices that are all indexed by node number;
// traversal starts at node 0.
type {{ class_name }} struct {
	// Lefts[node] is the node visited next when the tested feature is
	// less than or equal to the node's threshold.
	Lefts []int `json:"lefts"`
	// Rights[node] is the node visited next when the tested feature is
	// not less than or equal to the node's threshold.
	Rights []int `json:"rights"`
	// Thresholds[node] is the split value for the node; the value -2 marks
	// a node where traversal stops and Classes[node] is read.
	Thresholds []float64 `json:"thresholds"`
	// Indices[node] is the position in the feature vector that the node
	// compares against its threshold.
	Indices []int `json:"indices"`
	// Classes[node] holds one integer per class for the node; the largest
	// entry selects the predicted class and the entries are normalized to
	// produce probabilities.
	Classes [][]int `json:"classes"`
}

{% if is_test or to_json %}
// Response is the JSON document printed by main when the template is
// rendered with is_test or to_json set.
type Response struct {
	// Predict is the value returned by the estimator's Predict method.
	Predict int `json:"predict"`
	// PredictProba is the value returned by the estimator's PredictProba
	// method.
	PredictProba []float64 `json:"predict_proba"`
}

{% endif %}
// findMax returns the index of the largest value in nums. When several
// entries share the maximum the lowest index wins, and an empty slice
// yields 0.
func findMax(nums []int) int {
	best := 0
	for pos, val := range nums {
		// Strict comparison keeps the earliest maximum on ties.
		if val > nums[best] {
			best = pos
		}
	}
	return best
}

// normVals converts nums into fractions of their total. When the total is
// zero every entry receives the same share, 1/len(nums). The returned slice
// always has the same length as nums.
func normVals(nums []int) []float64 {
	total := 0
	for _, v := range nums {
		total += v
	}

	count := len(nums)
	shares := make([]float64, count)

	// A zero total cannot be divided by, so fall back to a uniform split.
	if total == 0 {
		for i := range shares {
			shares[i] = float64(1) / float64(count)
		}
		return shares
	}

	for i, v := range nums {
		shares[i] = float64(v) / float64(total)
	}
	return shares
}

func (dtc {{ class_name }}) predict(features []float64, node int) int {
	if dtc.Thresholds[node] != -2 {
		if features[dtc.Indices[node]] <= dtc.Thresholds[node] {
			return dtc.predict(features, dtc.Lefts[node])
		} else {
			return dtc.predict(features, dtc.Rights[node])
		}
	}
	return findMax(dtc.Classes[node])
}

func (dtc {{ class_name }}) Predict(features []float64) int {
	return dtc.predict(features, 0)
}

func (dtc {{ class_name }}) predictProba(features []float64, node int) []float64 {
	if dtc.Thresholds[node] != -2 {
		if features[dtc.Indices[node]] <= dtc.Thresholds[node] {
			return dtc.predictProba(features, dtc.Lefts[node])
		} else {
			return dtc.predictProba(features, dtc.Rights[node])
		}
	}
	return normVals(dtc.Classes[node])
}

func (dtc {{ class_name }}) PredictProba(features []float64) []float64 {
	return dtc.predictProba(features, 0)
}

// main loads a model from the JSON file named by the first command-line
// argument, parses every remaining argument as a float64 feature, and prints
// the prediction to standard output.
//
// Any failure (missing model path, unparseable feature, unreadable or invalid
// model file) is reported on standard error and the process exits with
// status 1 instead of continuing with partial data.
func main() {

	// fail reports a fatal problem and terminates. os.Exit skips deferred
	// calls, so nothing below relies on defer for cleanup.
	fail := func(format string, args ...interface{}) {
		fmt.Fprintf(os.Stderr, format+"\n", args...)
		os.Exit(1)
	}

	// Slicing os.Args[2:] or indexing os.Args[1] panics without a model path.
	if len(os.Args) < 2 {
		fail("usage: <program> <model.json> [feature ...]")
	}

	// Features:
	// A feature that does not parse must not be skipped: dropping it would
	// shift the position of every later feature and corrupt the prediction.
	features := make([]float64, 0, len(os.Args)-2)
	for i, arg := range os.Args[2:] {
		n, err := strconv.ParseFloat(arg, 64)
		if err != nil {
			fail("parsing feature %d (%q): %v", i, arg, err)
		}
		features = append(features, n)
	}

	// Model data:
	jsonFile, err := os.Open(os.Args[1])
	if err != nil {
		fail("opening model: %v", err)
	}
	byteValue, err := ioutil.ReadAll(jsonFile)
	// The file was only read, so a Close error carries no lost data.
	jsonFile.Close()
	if err != nil {
		fail("reading model %q: %v", os.Args[1], err)
	}

	// Estimator:
	var clf {{ class_name }}
	if err := json.Unmarshal(byteValue, &clf); err != nil {
		fail("decoding model %q: %v", os.Args[1], err)
	}
	// Traversal starts by reading Thresholds[0], so an empty model would panic.
	if len(clf.Thresholds) == 0 {
		fail("model %q contains no nodes", os.Args[1])
	}

	{% if is_test or to_json %}
	// Get JSON:
	prediction := clf.Predict(features)
	probabilities := clf.PredictProba(features)
	res, err := json.Marshal(&Response{Predict: prediction, PredictProba: probabilities})
	if err != nil {
		fail("encoding response: %v", err)
	}
	fmt.Println(string(res))
	{% else %}
	// Get class prediction:
	prediction := clf.Predict(features)
	fmt.Printf("Predicted class: #%d\n", prediction)
	probabilities := clf.PredictProba(features)
	for i, probability := range probabilities {
		fmt.Printf("Probability of class #%d : %f\n", i, probability)
	}
	{% endif %}

}
{% endblock %}
